{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "np.random.seed(42)\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.patches as mpatches\n",
    "import os\n",
    "import pandas as pd\n",
    "import time\n",
    "\n",
    "from processing.load_datasets import load_datasets\n",
    "from processing.configs import SINGLE_COLOR, COLORS_4\n",
    "\n",
    "from sklearn.manifold import TSNE\n",
    "from umap import UMAP\n",
    "from trimap import TRIMAP\n",
    "from pacmap import PaCMAP"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load datasets\n",
    "datasets_names = [\"rts\"]\n",
    "datasets = load_datasets(datasets_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings_dict = {}\n",
    "N_SAMPLE = 10000\n",
    "\n",
    "EMBEDDINGS_FOLDER = \"embeddings/params/\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## tSNE - Study of parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "perplexities = [5, 10, 50, 100]\n",
    "\n",
    "for dataset in datasets_names:\n",
    "    X = datasets[dataset]\n",
    "    if N_SAMPLE > 0 and N_SAMPLE < X.shape[0]:\n",
    "        rng = np.random.default_rng(42)\n",
    "        X = X[rng.choice(X.shape[0], N_SAMPLE, replace=False)]\n",
    "        \n",
    "    embeddings = []\n",
    "    for perplexity in perplexities:\n",
    "        save_folder = os.path.join(EMBEDDINGS_FOLDER, dataset)\n",
    "        save_path = os.path.join(save_folder, f\"{dataset}_tsne_p{perplexity}.npy\")\n",
    "        if os.path.exists(save_path):\n",
    "            print(f\"Loading existing t-SNE embedding for perplexity={perplexity}...\")\n",
    "            X_low = np.load(save_path)\n",
    "        else:\n",
    "            print(f\"Running t-SNE with perplexity={perplexity}...\")\n",
    "            X_low = TSNE(perplexity=perplexity).fit_transform(X)\n",
    "        \n",
    "        embeddings.append(X_low)\n",
    "    \n",
    "        if not os.path.exists(save_folder):\n",
    "            os.makedirs(save_folder)\n",
    "        np.save(save_path, X_low)\n",
    "\n",
    "    embeddings_dict[f\"{dataset}_tsne\"] = embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for dataset in datasets_names:\n",
    "    key = f\"{dataset}_tsne\"\n",
    "    fig, ax = plt.subplots(1, 4, figsize=(20, 5))\n",
    "    for i, perplexity in enumerate(perplexities):\n",
    "        ax[i].scatter(embeddings_dict[key][i][:, 0], embeddings_dict[key][i][:, 1], s=0.1, color=SINGLE_COLOR)\n",
    "        ax[i].set_title(f'Perplexity={perplexity}', fontsize=16, fontweight='bold')\n",
    "        ax[i].set_xticks([])\n",
    "        ax[i].set_yticks([])\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f\"images/params/tsne_params_{dataset}.png\", dpi=300, bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## UMAP - Study of parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_neighbors = [5, 30, 50, 100]\n",
    "min_distances = [0.1, 0.5, 0.99]\n",
    "\n",
    "for dataset in datasets_names:\n",
    "    X = datasets[dataset]\n",
    "    if N_SAMPLE > 0 and N_SAMPLE < X.shape[0]:\n",
    "        rng = np.random.default_rng(42)\n",
    "        X = X[rng.choice(X.shape[0], N_SAMPLE, replace=False)]\n",
    "\n",
    "    embeddings = []\n",
    "    for d in min_distances:\n",
    "        for n in n_neighbors:\n",
    "            save_folder = os.path.join(EMBEDDINGS_FOLDER, dataset)\n",
    "            save_path = os.path.join(save_folder, f\"{dataset}_umap_n{n}_d{d}.npy\")\n",
    "\n",
    "            if os.path.exists(save_path):\n",
    "                print(f\"Loading existing UMAP embedding for n_neighbors={n} and min_dist={d}...\")\n",
    "                X_low = np.load(save_path)\n",
    "            else:\n",
    "                print(f\"Running UMAP with n_neighbors={n} and min_dist={d}...\")\n",
    "                X_low = UMAP(n_neighbors=n, min_dist=d).fit_transform(X)\n",
    "            embeddings.append(X_low)\n",
    "\n",
    "            if not os.path.exists(save_folder):\n",
    "                os.makedirs(save_folder)\n",
    "            np.save(save_path, X_low)\n",
    "            \n",
    "    embeddings_dict[f\"{dataset}_umap\"] = embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for dataset in datasets_names:\n",
    "    key = f\"{dataset}_umap\"\n",
    "    fig, axs = plt.subplots(len(min_distances), len(n_neighbors), figsize=(20, 15))\n",
    "    counter = 0\n",
    "    for i, min_dist in enumerate(min_distances):\n",
    "        for j, n in enumerate(n_neighbors):\n",
    "            axs[i,j].scatter(embeddings_dict[key][counter][:, 0], embeddings_dict[key][counter][:, 1], s=0.1, color=SINGLE_COLOR)\n",
    "            #axs[i,j].set_title(f'n_neighbors={n}, min_dist={min_dist}', fontsize=16, fontweight='bold')\n",
    "            axs[i,j].set_xticks([])\n",
    "            axs[i,j].set_yticks([])\n",
    "            counter += 1\n",
    "            \n",
    "    for i, min_dist in enumerate(min_distances):\n",
    "        axs[i,0].set_ylabel(f'min_dist={min_dist}', fontsize=16, fontweight='bold')\n",
    "    for j, n in enumerate(n_neighbors):\n",
    "        axs[0,j].set_title(f'n_neighbors={n}', fontsize=16, fontweight='bold')\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f\"images/params/umap_params_{dataset}.png\", dpi=300, bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TRIMAP - Study of parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c_values = [5, 30, 50, 100]\n",
    "\n",
    "for dataset in datasets_names:\n",
    "    X = datasets[dataset]\n",
    "    if N_SAMPLE > 0 and N_SAMPLE < X.shape[0]:\n",
    "        rng = np.random.default_rng(42)\n",
    "        X = X[rng.choice(X.shape[0], N_SAMPLE, replace=False)]\n",
    "\n",
    "    embeddings = []\n",
    "    for c in c_values:\n",
    "        save_folder = os.path.join(EMBEDDINGS_FOLDER, dataset)\n",
    "        save_path = os.path.join(save_folder, f\"{dataset}_trimap_c{c}.npy\")\n",
    "        if os.path.exists(save_path):\n",
    "            print(f\"Loading existing TRIMAP embedding for n_inliers={2*c}, n_outliers={c}, n_random={c}...\")\n",
    "            X_low = np.load(save_path)\n",
    "        else:\n",
    "            print(f\"Running TRIMAP with n_inliers={2*c}, n_outliers={c}, n_random={c}...\")\n",
    "            X_low = TRIMAP(n_inliers=2*c, n_outliers=c, n_random=c).fit_transform(X)\n",
    "\n",
    "        embeddings.append(X_low)\n",
    "\n",
    "        if not os.path.exists(save_folder):\n",
    "            os.makedirs(save_folder)\n",
    "        np.save(save_path, X_low)\n",
    "\n",
    "    embeddings_dict[f\"{dataset}_trimap\"] = embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for dataset in datasets_names:\n",
    "    key = f\"{dataset}_trimap\"\n",
    "    fig, axs = plt.subplots(1, 4, figsize=(20, 5))\n",
    "    for i, c in enumerate(c_values):\n",
    "        axs[i].scatter(embeddings_dict[key][i][:, 0], embeddings_dict[key][i][:, 1], s=0.1, color=SINGLE_COLOR)\n",
    "        axs[i].set_title(f'triplets={c}(2,1,1)', fontsize=16, fontweight='bold')\n",
    "        axs[i].set_xticks([])\n",
    "        axs[i].set_yticks([])\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f\"images/params/trimap_params_{dataset}.png\", dpi=300, bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PACMAP - Study of parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mn_ratios = [0.1, 0.5, 1, 5]\n",
    "fp_ratios = [0.5, 2, 5, 10]\n",
    "\n",
    "for dataset in datasets_names:\n",
    "    X = datasets[dataset]\n",
    "    if N_SAMPLE > 0 and N_SAMPLE < X.shape[0]:\n",
    "        rng = np.random.default_rng(42)\n",
    "        X = X[rng.choice(X.shape[0], N_SAMPLE, replace=False)]\n",
    "\n",
    "    embeddings = []\n",
    "    for mn in mn_ratios:\n",
    "        for fp in fp_ratios:\n",
    "            save_folder = os.path.join(EMBEDDINGS_FOLDER, dataset)\n",
    "            save_path = os.path.join(save_folder, f\"{dataset}_pacmap_mn{mn}_fp{fp}.npy\")\n",
    "\n",
    "            if os.path.exists(save_path):\n",
    "                print(f\"Loading existing PaCMAP embedding for MN_ratio={mn} and FP_ratio={fp}...\")\n",
    "                X_low = np.load(save_path)\n",
    "            else:\n",
    "                print(f\"Running PaCMAP with MN_ratio={mn} and FP_ratio={fp}...\")\n",
    "                X_low = PaCMAP(MN_ratio=mn, FP_ratio=fp).fit_transform(X)\n",
    "            \n",
    "            embeddings.append(X_low)\n",
    "\n",
    "            if not os.path.exists(save_folder):\n",
    "                os.makedirs(save_folder)\n",
    "            np.save(save_path, X_low)\n",
    "            \n",
    "    embeddings_dict[f\"{dataset}_pacmap_ratio\"] = embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for dataset in datasets_names:\n",
    "    key = f\"{dataset}_pacmap_ratio\"\n",
    "    fig, axs = plt.subplots(len(mn_ratios), len(fp_ratios), figsize=(20, 20))\n",
    "    counter = 0\n",
    "    for i, mn in enumerate(mn_ratios):\n",
    "        for j, fp in enumerate(fp_ratios):\n",
    "            axs[i,j].scatter(embeddings_dict[key][counter][:, 0], embeddings_dict[key][counter][:, 1], s=0.1, color=SINGLE_COLOR)\n",
    "            axs[i,j].set_xticks([])\n",
    "            axs[i,j].set_yticks([])\n",
    "            counter += 1\n",
    "            \n",
    "    for i, mn in enumerate(mn_ratios):\n",
    "        axs[i,0].set_ylabel(f'MN_ratio={mn}', fontsize=16, fontweight='bold')\n",
    "    for j, fp in enumerate(fp_ratios):\n",
    "        axs[0,j].set_title(f'FP_ratio={fp}', fontsize=16, fontweight='bold')\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f\"images/params/pacmap_params_ratios_{dataset}.png\", dpi=300, bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_values = [5, 30, 50, 100]\n",
    "\n",
    "for dataset in datasets_names:\n",
    "    X = datasets[dataset]\n",
    "    if N_SAMPLE > 0 and N_SAMPLE < X.shape[0]:\n",
    "        rng = np.random.default_rng(42)\n",
    "        X = X[rng.choice(X.shape[0], N_SAMPLE, replace=False)]\n",
    "\n",
    "    embeddings = []\n",
    "    for n in n_values:\n",
    "        save_folder = os.path.join(EMBEDDINGS_FOLDER, dataset)\n",
    "        save_path = os.path.join(save_folder, f\"{dataset}_pacmap_n{n}.npy\")        \n",
    "        if os.path.exists(save_path):\n",
    "            print(f\"Loading existing PaCMAP embedding for n_neighbors={n}...\")\n",
    "            X_low = np.load(save_path)\n",
    "        else:\n",
    "            print(f\"Running PaCMAP with n_neighbors={n}...\")\n",
    "            X_low = PaCMAP(n_neighbors=n).fit_transform(X)\n",
    "        \n",
    "        embeddings.append(X_low)\n",
    "\n",
    "        if not os.path.exists(save_folder):\n",
    "            os.makedirs(save_folder)\n",
    "        np.save(save_path, X_low)\n",
    "        \n",
    "    embeddings_dict[f\"{dataset}_pacmap_n\"] = embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for dataset in datasets_names:\n",
    "    key = f\"{dataset}_pacmap_n\"\n",
    "    fig, axs = plt.subplots(1, 4, figsize=(20, 5))\n",
    "    for i, n in enumerate(n_values):\n",
    "        axs[i].scatter(embeddings_dict[key][i][:, 0], embeddings_dict[key][i][:, 1], s=0.1, color=SINGLE_COLOR)\n",
    "        axs[i].set_title(f'n_neighbors={n}', fontsize=16, fontweight='bold')\n",
    "        axs[i].set_xticks([])\n",
    "        axs[i].set_yticks([])\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f\"images/params/pacmap_params_n_{dataset}.png\", dpi=300, bbox_inches='tight')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Parameters stability evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute Procustes distance between embeddings\n",
    "from scipy.spatial import procrustes\n",
    "\n",
    "def procrustes_distance(X, Y):\n",
    "    _, _, dist = procrustes(X, Y)\n",
    "    return dist\n",
    "\n",
    "# Compute pairwise Procrustes distances between embeddings\n",
    "def compute_procrustes_distances(embeddings):\n",
    "    n = len(embeddings)\n",
    "    distances = np.zeros((n, n))\n",
    "    for i in range(n):\n",
    "        for j in range(n):\n",
    "            distances[i, j] = procrustes_distance(embeddings[i], embeddings[j])\n",
    "    return distances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings_procustes_d = {key:compute_procrustes_distances(embeddings) for key, embeddings in embeddings_dict.items()}\n",
    "embeddings_procustes_d_eval = {key:(np.mean(distances), np.std(distances)) for key, distances in embeddings_procustes_d.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "embeddings_procustes_d_eval = pd.DataFrame(embeddings_procustes_d_eval, index=[\"mean\", \"std\"]).T.reset_index().rename(columns={\"index\":\"dataset_algo\"})\n",
    "embeddings_procustes_d_eval[\"dataset\"] = embeddings_procustes_d_eval[\"dataset_algo\"].apply(lambda x: x.split(\"_\")[0])\n",
    "embeddings_procustes_d_eval[\"algo\"] = embeddings_procustes_d_eval[\"dataset_algo\"].apply(lambda x: x.split(\"_\")[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "barwidth = 0.2\n",
    "algo_names = [\"tsne\", \"umap\", \"trimap\", \"pacmap\"]\n",
    "n_datasets = len(datasets_names)\n",
    "n_algos = len(algo_names)\n",
    "\n",
    "legend_patches = [mpatches.Patch(color=COLORS_4(i), label=algo.upper()) for i, algo in enumerate(algo_names)]\n",
    "\n",
    "# Plot neighborhood preservation on single plot\n",
    "fig, ax = plt.subplots(figsize=(10, 4))\n",
    "for i, dataset in enumerate(datasets_names):\n",
    "    for j, algo in enumerate(algo_names):\n",
    "        df = embeddings_procustes_d_eval[(embeddings_procustes_d_eval[\"dataset\"] == dataset) & (embeddings_procustes_d_eval[\"algo\"] == algo)]\n",
    "        ax.bar(i + j * barwidth, df[\"mean\"], yerr=df[\"std\"], width=barwidth, label=f\"{algo}\", color = COLORS_4(j))\n",
    "ax.set_xticks(np.arange(n_datasets) + barwidth * (n_algos - 1) / 2)\n",
    "ax.set_xticklabels([w.upper() for w in datasets_names], fontsize = 12, fontweight='bold')\n",
    "ax.set_ylabel(\"Procustes distance\", fontsize = 12, fontweight='bold')\n",
    "ax.legend(handles=legend_patches, loc=\"upper left\", bbox_to_anchor=(1, 1))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dr_eval",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
